{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multivariate time series classification with sktime\n",
    "\n",
    "In this notebook, we will use sktime for multivariate time series classification.\n",
    "\n",
    "For the simpler univariate time series classification setting, take a look at this [notebook](https://github.com/alan-turing-institute/sktime/blob/main/examples/02_classification_univariate.ipynb).\n",
    "\n",
    "### Preliminaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:48.609970Z",
     "iopub.status.busy": "2021-04-09T13:56:48.609298Z",
     "iopub.status.idle": "2021-04-09T13:56:49.811948Z",
     "shell.execute_reply": "2021-04-09T13:56:49.812468Z"
    },
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "\n",
    "from sktime.classification.compose import ColumnEnsembleClassifier\n",
    "from sktime.classification.dictionary_based import BOSSEnsemble\n",
    "from sktime.classification.interval_based import TimeSeriesForestClassifier\n",
    "from sktime.classification.shapelet_based import MrSEQLClassifier\n",
    "from sktime.datasets import load_basic_motions\n",
    "from sktime.transformations.panel.compose import ColumnConcatenator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load multivariate time series/panel data\n",
    "\n",
    "The [data set](http://www.timeseriesclassification.com/description.php?Dataset=BasicMotions) we use in this notebook was generated as part of a student project where four students performed four activities whilst wearing a smart watch. The watch records a 3D accelerometer and a 3D gyroscope, giving six dimensions. The data set consists of four classes: walking, resting (labelled `standing` in the data), running and badminton. Each participant recorded each activity five times, giving 80 instances in total, and the data is sampled once every tenth of a second for a ten-second period, i.e. 100 time points per dimension."
   ]
  },
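  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this layout concrete: 3 accelerometer axes plus 3 gyroscope axes give 6 dimensions, and 10 samples per second over 10 seconds give 100 time points per dimension. The sketch below is a hypothetical illustration rather than sktime code. It builds a small random nested `pd.DataFrame` in the same style as the data loaded below (one row per instance, one column per dimension, each cell holding a whole `pd.Series`), then stacks it into a NumPy array of shape `(n_instances, n_dimensions, n_timepoints)`. The helper `nested_to_3d` is a made-up name; sktime provides its own conversion utilities, whose names and locations depend on the version.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# 6 dimensions (3D accelerometer + 3D gyroscope), 10 Hz x 10 s = 100 time points\n",
    "n_instances, n_dims, n_timepoints = 4, 6, 100\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Toy nested frame: each cell holds a whole pd.Series (random values, illustration only)\n",
    "columns = {}\n",
    "for d in range(n_dims):\n",
    "    cells = np.empty(n_instances, dtype=object)\n",
    "    for i in range(n_instances):\n",
    "        cells[i] = pd.Series(rng.normal(size=n_timepoints))\n",
    "    columns[f'dim_{d}'] = cells\n",
    "X_nested = pd.DataFrame(columns)\n",
    "\n",
    "\n",
    "def nested_to_3d(X):\n",
    "    # Hypothetical helper: stack the cells into (n_instances, n_dims, n_timepoints)\n",
    "    return np.stack(\n",
    "        [\n",
    "            np.stack([X.iloc[i, j].to_numpy() for j in range(X.shape[1])])\n",
    "            for i in range(X.shape[0])\n",
    "        ]\n",
    "    )\n",
    "\n",
    "\n",
    "X_3d = nested_to_3d(X_nested)\n",
    "print(X_nested.shape, X_3d.shape)  # (4, 6) (4, 6, 100)\n",
    "```"
   ]
  },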
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:49.816449Z",
     "iopub.status.busy": "2021-04-09T13:56:49.815969Z",
     "iopub.status.idle": "2021-04-09T13:56:49.885003Z",
     "shell.execute_reply": "2021-04-09T13:56:49.885481Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(60, 6) (60,) (20, 6) (20,)\n"
     ]
    }
   ],
   "source": [
    "X, y = load_basic_motions(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
    "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:49.913836Z",
     "iopub.status.busy": "2021-04-09T13:56:49.913366Z",
     "iopub.status.idle": "2021-04-09T13:56:49.937176Z",
     "shell.execute_reply": "2021-04-09T13:56:49.937747Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dim_0</th>\n",
       "      <th>dim_1</th>\n",
       "      <th>dim_2</th>\n",
       "      <th>dim_3</th>\n",
       "      <th>dim_4</th>\n",
       "      <th>dim_5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0    -0.407421\n",
       "1    -0.407421\n",
       "2     2.355158\n",
       "3...</td>\n",
       "      <td>0     1.413374\n",
       "1     1.413374\n",
       "2    -3.928032\n",
       "3...</td>\n",
       "      <td>0     0.092782\n",
       "1     0.092782\n",
       "2    -0.211622\n",
       "3...</td>\n",
       "      <td>0    -0.066584\n",
       "1    -0.066584\n",
       "2    -3.630177\n",
       "3...</td>\n",
       "      <td>0     0.223723\n",
       "1     0.223723\n",
       "2    -0.026634\n",
       "3...</td>\n",
       "      <td>0     0.135832\n",
       "1     0.135832\n",
       "2    -1.946925\n",
       "3...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0     0.383922\n",
       "1     0.383922\n",
       "2    -0.272575\n",
       "3...</td>\n",
       "      <td>0     0.302612\n",
       "1     0.302612\n",
       "2    -1.381236\n",
       "3...</td>\n",
       "      <td>0    -0.398075\n",
       "1    -0.398075\n",
       "2    -0.681258\n",
       "3...</td>\n",
       "      <td>0     0.071911\n",
       "1     0.071911\n",
       "2    -0.761725\n",
       "3...</td>\n",
       "      <td>0     0.175783\n",
       "1     0.175783\n",
       "2    -0.114525\n",
       "3...</td>\n",
       "      <td>0    -0.087891\n",
       "1    -0.087891\n",
       "2    -0.503377\n",
       "3...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0    -0.357300\n",
       "1    -0.357300\n",
       "2    -0.005055\n",
       "3...</td>\n",
       "      <td>0    -0.584885\n",
       "1    -0.584885\n",
       "2     0.295037\n",
       "3...</td>\n",
       "      <td>0    -0.792751\n",
       "1    -0.792751\n",
       "2     0.213664\n",
       "3...</td>\n",
       "      <td>0     0.074574\n",
       "1     0.074574\n",
       "2    -0.157139\n",
       "3...</td>\n",
       "      <td>0     0.159802\n",
       "1     0.159802\n",
       "2    -0.306288\n",
       "3...</td>\n",
       "      <td>0     0.023970\n",
       "1     0.023970\n",
       "2     1.230478\n",
       "3...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0    -0.352746\n",
       "1    -0.352746\n",
       "2    -1.354561\n",
       "3...</td>\n",
       "      <td>0     0.316845\n",
       "1     0.316845\n",
       "2     0.490525\n",
       "3...</td>\n",
       "      <td>0    -0.473779\n",
       "1    -0.473779\n",
       "2     1.454261\n",
       "3...</td>\n",
       "      <td>0    -0.327595\n",
       "1    -0.327595\n",
       "2    -0.269001\n",
       "3...</td>\n",
       "      <td>0     0.106535\n",
       "1     0.106535\n",
       "2     0.021307\n",
       "3...</td>\n",
       "      <td>0     0.197090\n",
       "1     0.197090\n",
       "2     0.460763\n",
       "3...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>0      0.052231\n",
       "1      0.052231\n",
       "2     -0.54804...</td>\n",
       "      <td>0     -0.730486\n",
       "1     -0.730486\n",
       "2      0.70700...</td>\n",
       "      <td>0    -0.518104\n",
       "1    -0.518104\n",
       "2    -1.179430\n",
       "3...</td>\n",
       "      <td>0    -0.159802\n",
       "1    -0.159802\n",
       "2    -0.239704\n",
       "3...</td>\n",
       "      <td>0    -0.045277\n",
       "1    -0.045277\n",
       "2     0.023970\n",
       "3...</td>\n",
       "      <td>0     -0.029297\n",
       "1     -0.029297\n",
       "2      0.29829...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                dim_0  \\\n",
       "9   0    -0.407421\n",
       "1    -0.407421\n",
       "2     2.355158\n",
       "3...   \n",
       "24  0     0.383922\n",
       "1     0.383922\n",
       "2    -0.272575\n",
       "3...   \n",
       "5   0    -0.357300\n",
       "1    -0.357300\n",
       "2    -0.005055\n",
       "3...   \n",
       "7   0    -0.352746\n",
       "1    -0.352746\n",
       "2    -1.354561\n",
       "3...   \n",
       "34  0      0.052231\n",
       "1      0.052231\n",
       "2     -0.54804...   \n",
       "\n",
       "                                                dim_1  \\\n",
       "9   0     1.413374\n",
       "1     1.413374\n",
       "2    -3.928032\n",
       "3...   \n",
       "24  0     0.302612\n",
       "1     0.302612\n",
       "2    -1.381236\n",
       "3...   \n",
       "5   0    -0.584885\n",
       "1    -0.584885\n",
       "2     0.295037\n",
       "3...   \n",
       "7   0     0.316845\n",
       "1     0.316845\n",
       "2     0.490525\n",
       "3...   \n",
       "34  0     -0.730486\n",
       "1     -0.730486\n",
       "2      0.70700...   \n",
       "\n",
       "                                                dim_2  \\\n",
       "9   0     0.092782\n",
       "1     0.092782\n",
       "2    -0.211622\n",
       "3...   \n",
       "24  0    -0.398075\n",
       "1    -0.398075\n",
       "2    -0.681258\n",
       "3...   \n",
       "5   0    -0.792751\n",
       "1    -0.792751\n",
       "2     0.213664\n",
       "3...   \n",
       "7   0    -0.473779\n",
       "1    -0.473779\n",
       "2     1.454261\n",
       "3...   \n",
       "34  0    -0.518104\n",
       "1    -0.518104\n",
       "2    -1.179430\n",
       "3...   \n",
       "\n",
       "                                                dim_3  \\\n",
       "9   0    -0.066584\n",
       "1    -0.066584\n",
       "2    -3.630177\n",
       "3...   \n",
       "24  0     0.071911\n",
       "1     0.071911\n",
       "2    -0.761725\n",
       "3...   \n",
       "5   0     0.074574\n",
       "1     0.074574\n",
       "2    -0.157139\n",
       "3...   \n",
       "7   0    -0.327595\n",
       "1    -0.327595\n",
       "2    -0.269001\n",
       "3...   \n",
       "34  0    -0.159802\n",
       "1    -0.159802\n",
       "2    -0.239704\n",
       "3...   \n",
       "\n",
       "                                                dim_4  \\\n",
       "9   0     0.223723\n",
       "1     0.223723\n",
       "2    -0.026634\n",
       "3...   \n",
       "24  0     0.175783\n",
       "1     0.175783\n",
       "2    -0.114525\n",
       "3...   \n",
       "5   0     0.159802\n",
       "1     0.159802\n",
       "2    -0.306288\n",
       "3...   \n",
       "7   0     0.106535\n",
       "1     0.106535\n",
       "2     0.021307\n",
       "3...   \n",
       "34  0    -0.045277\n",
       "1    -0.045277\n",
       "2     0.023970\n",
       "3...   \n",
       "\n",
       "                                                dim_5  \n",
       "9   0     0.135832\n",
       "1     0.135832\n",
       "2    -1.946925\n",
       "3...  \n",
       "24  0    -0.087891\n",
       "1    -0.087891\n",
       "2    -0.503377\n",
       "3...  \n",
       "5   0     0.023970\n",
       "1     0.023970\n",
       "2     1.230478\n",
       "3...  \n",
       "7   0     0.197090\n",
       "1     0.197090\n",
       "2     0.460763\n",
       "3...  \n",
       "34  0     -0.029297\n",
       "1     -0.029297\n",
       "2      0.29829...  "
      ]
     },
      "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# multivariate input data (one column per dimension)\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:49.942463Z",
     "iopub.status.busy": "2021-04-09T13:56:49.941766Z",
     "iopub.status.idle": "2021-04-09T13:56:49.944096Z",
     "shell.execute_reply": "2021-04-09T13:56:49.944449Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['badminton', 'running', 'standing', 'walking'], dtype=object)"
      ]
     },
      "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# multi-class target variable\n",
    "np.unique(y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multivariate classification\n",
    "sktime offers three main ways of solving multivariate time series classification problems:\n",
    "\n",
    "1. _Concatenation_ of time series columns into a single long time series column via `ColumnConcatenator`, followed by a classifier applied to the concatenated data,\n",
    "2. _Column-wise ensembling_ via `ColumnEnsembleClassifier` in which one classifier is fitted for each time series column and their predictions aggregated,\n",
    "3. _Bespoke estimator-specific methods_ for handling multivariate time series data, e.g. finding shapelets in multidimensional spaces (still work in progress).\n",
    "\n",
    "### Time series concatenation\n",
    "We can concatenate multivariate time series/panel data into a single long univariate time series/panel and then apply a classifier to the univariate data."
   ]
  },
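  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, concatenation lays each instance's dimensions end to end, so for BasicMotions 6 dimensions of 100 time points each become a single series of 600 time points. The NumPy sketch below illustrates only that idea, on a 3D array of shape `(n_instances, n_dimensions, n_timepoints)`. It is not the implementation of `ColumnConcatenator`, which operates on sktime's nested data; the ordering (dimension 0 first, then dimension 1, and so on) is an assumption of the sketch, and `concatenate_dimensions` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def concatenate_dimensions(X):\n",
    "    # X has shape (n_instances, n_dims, n_timepoints); a C-order reshape puts\n",
    "    # dimension 0 first, then dimension 1, and so on, for every instance\n",
    "    n_instances, n_dims, n_timepoints = X.shape\n",
    "    return X.reshape(n_instances, n_dims * n_timepoints)\n",
    "\n",
    "\n",
    "X_toy = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # 2 instances, 3 dims, 4 time points\n",
    "X_long = concatenate_dimensions(X_toy)\n",
    "print(X_long.shape)  # (2, 12)\n",
    "print(X_long[0])  # [ 0  1  2  3  4  5  6  7  8  9 10 11]\n",
    "```\n",
    "\n",
    "A univariate classifier fitted on the long series sees all dimensions at once, but it has no explicit information about where one dimension ends and the next begins."
   ]
  },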
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:49.948181Z",
     "iopub.status.busy": "2021-04-09T13:56:49.947690Z",
     "iopub.status.idle": "2021-04-09T13:56:50.764554Z",
     "shell.execute_reply": "2021-04-09T13:56:50.765044Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
      "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "steps = [\n",
    "    (\"concatenate\", ColumnConcatenator()),\n",
    "    (\"classify\", TimeSeriesForestClassifier(n_estimators=100)),\n",
    "]\n",
    "clf = Pipeline(steps)\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Column ensembling\n",
    "We can also fit one classifier for each time series column and then aggregate their predictions. The interface is similar to scikit-learn's familiar `ColumnTransformer`: each entry is a tuple of a name, an estimator and the column indices it is fitted on."
   ]
  },
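  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind column ensembling can be sketched without sktime: fit one model per selected column, then combine the models' class probabilities. In the hypothetical NumPy sketch below, each column gets a toy nearest-centroid model (`NearestCentroidModel`, `fit_column_ensemble` and `predict_column_ensemble` are made-up names), and the ensemble averages the per-column probabilities. Plain averaging is an assumed aggregation rule; `ColumnEnsembleClassifier` may weight or combine its estimators differently. The data are random, with only column 0 carrying a class signal, and columns `[0, 3]` mirror the selection used in the next cell.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class NearestCentroidModel:\n",
    "    # Toy classifier (hypothetical): class scores from squared distance to class means\n",
    "    def fit(self, X, y):\n",
    "        self.classes_ = np.unique(y)\n",
    "        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])\n",
    "        return self\n",
    "\n",
    "    def predict_proba(self, X):\n",
    "        # Squared distance from each series to each class centroid: (n_instances, n_classes)\n",
    "        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)\n",
    "        # Softmax over negative distances, shifted for numerical stability\n",
    "        w = np.exp(-(d - d.min(axis=1, keepdims=True)))\n",
    "        return w / w.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def fit_column_ensemble(X, y, columns):\n",
    "    # One model per selected column; X has shape (n_instances, n_dims, n_timepoints)\n",
    "    return [(c, NearestCentroidModel().fit(X[:, c, :], y)) for c in columns]\n",
    "\n",
    "\n",
    "def predict_column_ensemble(models, X):\n",
    "    # Average the per-column probabilities, then pick the most probable class\n",
    "    probas = np.mean([m.predict_proba(X[:, c, :]) for c, m in models], axis=0)\n",
    "    return models[0][1].classes_[probas.argmax(axis=1)]\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "y_toy = np.repeat(np.array(['a', 'b']), 20)\n",
    "X_toy = rng.normal(size=(40, 6, 100))\n",
    "X_toy[y_toy == 'b', 0, :] += 2.0  # only column 0 carries a class signal\n",
    "models = fit_column_ensemble(X_toy, y_toy, columns=[0, 3])\n",
    "accuracy = (predict_column_ensemble(models, X_toy) == y_toy).mean()\n",
    "print(accuracy)\n",
    "```"
   ]
  },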
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:56:50.773607Z",
     "iopub.status.busy": "2021-04-09T13:56:50.773093Z",
     "iopub.status.idle": "2021-04-09T13:57:00.972349Z",
     "shell.execute_reply": "2021-04-09T13:57:00.972821Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9"
      ]
     },
      "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = ColumnEnsembleClassifier(\n",
    "    estimators=[\n",
    "        (\"TSF0\", TimeSeriesForestClassifier(n_estimators=100), [0]),\n",
    "        (\"BOSSEnsemble3\", BOSSEnsemble(max_ensemble_size=5), [3]),\n",
    "    ]\n",
    ")\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bespoke classification algorithms\n",
    "Another approach is to use bespoke (or classifier-specific) methods for multivariate time series data. Here, we try out the MrSEQL algorithm in multidimensional space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-04-09T13:57:00.984985Z",
     "iopub.status.busy": "2021-04-09T13:57:00.984416Z",
     "iopub.status.idle": "2021-04-09T13:57:09.377483Z",
     "shell.execute_reply": "2021-04-09T13:57:09.378081Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
      "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = MrSEQLClassifier()\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
